package controller

import (
	"fmt"
	"io"
	"os/exec"
	"path/filepath"
	"strings"

	"github.com/infraboard/mcube/logger"
	"github.com/infraboard/mcube/logger/zap"
)

// NewScriptCollector returns a ScriptCollector that looks for scripts under
// homeDir. Its logger is obtained from zap.L() and named "script".
func NewScriptCollector(homeDir string) *ScriptCollector {
	collector := new(ScriptCollector)
	collector.homeDir = homeDir
	collector.log = zap.L().Named("script")
	return collector
}

// ScriptCollector executes scripts located under homeDir and copies their
// standard output to a caller-supplied writer.
type ScriptCollector struct {
	// homeDir is the directory in which the scripts are stored.
	homeDir string

	// log receives the collector's debug messages.
	log logger.Logger
}

// Exec runs the script that belongs to module and streams its standard output
// to dst.
//
// The script path is resolved with find. The interpreter is chosen from the
// file extension: ".sh" runs under bash, ".py" under python, and anything else
// is executed directly. params is handed to the script as one single argument;
// it is not split on whitespace.
//
// Exec blocks until the script has exited. It returns an error when the module
// cannot be resolved, the command cannot be started, copying the output to dst
// fails, or the script exits with a non-zero status.
func (c *ScriptCollector) Exec(module, params string, dst io.Writer) error {
	// Resolve the module name to the script's location.
	script, err := c.find(module)
	if err != nil {
		return err
	}
	if script == "" {
		return fmt.Errorf("module %s not found", module)
	}
	c.log.Debugf("exec script: %s", script)

	// The file extension decides how the script is launched.
	var cmd *exec.Cmd
	switch filepath.Ext(script) {
	case ".sh":
		cmd = exec.Command("bash", script, params)
	case ".py":
		cmd = exec.Command("python", script, params)
	default:
		cmd = exec.Command(script, params)
	}

	// Hand dst to os/exec instead of copying from a StdoutPipe by hand: Run
	// streams the output, waits for the process to exit (so the child is
	// always reaped) and reports both exit-status and copy errors.
	cmd.Stdout = dst

	if err := cmd.Run(); err != nil {
		return fmt.Errorf("exec script %s: %w", script, err)
	}
	return nil
}

// find resolves module to the absolute path of its script under homeDir.
//
// module is treated as a path relative to homeDir. A module that has ".." as
// one of its path elements is rejected, so the joined path cannot lexically
// leave homeDir. Names that merely contain two consecutive dots (for example
// "a..b.sh") are accepted. find does not resolve symlinks and does not check
// that the resulting file exists.
func (c *ScriptCollector) find(module string) (string, error) {
	// Keep the caller-supplied module from climbing out of the script
	// directory. Both '/' and the OS separator split path elements.
	isSep := func(r rune) bool { return r == '/' || r == filepath.Separator }
	for _, elem := range strings.FieldsFunc(module, isSep) {
		if elem == ".." {
			return "", fmt.Errorf("module %q forbidden: contains .. path element", module)
		}
	}

	absPath, err := filepath.Abs(c.homeDir)
	if err != nil {
		return "", fmt.Errorf("find module %s abs path error %w", module, err)
	}

	return filepath.Join(absPath, module), nil
}
